{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reload edited project modules automatically, so changes in `detector`\n",
    "# are picked up without restarting the kernel.\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "from PIL import Image, ImageDraw\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "\n",
    "import sys\n",
    "sys.path.append('..')\n",
    "from detector.anchor_generator import AnchorGenerator\n",
    "from detector.training_target_creation import create_targets, match_boxes\n",
    "from detector.input_pipeline import Pipeline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def draw_boxes(image, boxes, color='red', fill=None):\n",
    "    \n",
    "    w, h = image.size\n",
    "    scaler = np.array([h, w, h, w], dtype='float32')\n",
    "    boxes = boxes.copy() * scaler\n",
    "    \n",
    "    image_copy = image.copy()\n",
    "    draw = ImageDraw.Draw(image_copy, 'RGBA')\n",
    "    \n",
    "    for box in boxes:\n",
    "        ymin, xmin, ymax, xmax = box\n",
    "        outline = color\n",
    "        draw.rectangle(\n",
    "            [(xmin, ymin), (xmax, ymax)],\n",
    "            fill=fill, outline=outline\n",
    "        )\n",
    "    return image_copy"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Create a graph: anchors, matches and training targets for one image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tf.reset_default_graph()\n",
    "\n",
    "# input pipeline that reads a single training shard\n",
    "train_shards = ['/home/gpu1/datasets/COCO/coco_person/train_shards/shard-0000.tfrecords']\n",
    "params = {\n",
    "    'min_dimension': 640,\n",
    "    'batch_size': 14,\n",
    "    'image_height': 640,\n",
    "    'image_width': 640,\n",
    "}\n",
    "pipeline = Pipeline(train_shards, is_training=True, params=params)\n",
    "iterator = pipeline.dataset.make_one_shot_iterator()\n",
    "features, labels = iterator.get_next()\n",
    "\n",
    "# only the first example of the batch is inspected\n",
    "image = features['images'][0]\n",
    "groundtruth_boxes = labels['boxes'][0]\n",
    "groundtruth_labels = labels['labels'][0]\n",
    "\n",
    "# anchors are generated for the actual size of that image\n",
    "image_shape = tf.shape(image)\n",
    "image_height, image_width = image_shape[0], image_shape[1]\n",
    "anchor_generator = AnchorGenerator(\n",
    "    strides=[8, 16, 32, 64, 128],\n",
    "    scales=[32, 64, 128, 256, 512],\n",
    "    scale_multipliers=[1.0, 1.4142],\n",
    "    aspect_ratios=[1.0, 2.0, 0.5]\n",
    ")\n",
    "anchors = anchor_generator(image_height, image_width)\n",
    "\n",
    "# only the last level:\n",
    "# n = anchor_generator.num_anchors_per_feature_map[-1]\n",
    "# anchors = anchors[-n:]\n",
    "\n",
    "num_anchors = tf.shape(anchors)[0]\n",
    "num_groundtruth = tf.shape(groundtruth_boxes)[0]\n",
    "\n",
    "# used when the image has no boxes: every anchor gets the value -1\n",
    "only_background = tf.fill([num_anchors], -1)\n",
    "\n",
    "\n",
    "def match_anchors_to_groundtruth():\n",
    "    return match_boxes(\n",
    "        anchors, groundtruth_boxes,\n",
    "        positives_threshold=0.5,\n",
    "        negatives_threshold=0.4,\n",
    "        force_match_groundtruth=True\n",
    "    )\n",
    "\n",
    "\n",
    "# matching is only attempted when there is at least one groundtruth box\n",
    "raw_matches = tf.cond(\n",
    "    tf.greater(num_groundtruth, 0),\n",
    "    match_anchors_to_groundtruth,\n",
    "    lambda: only_background\n",
    ")\n",
    "matches = tf.to_int32(raw_matches)\n",
    "\n",
    "reg_targets, cls_targets = create_targets(\n",
    "    anchors, groundtruth_boxes,\n",
    "    groundtruth_labels, matches\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fetches = {\n",
    "    'image': image,\n",
    "    'anchors': anchors,\n",
    "    'matches': matches,\n",
    "    'boxes': groundtruth_boxes\n",
    "}\n",
    "with tf.Session() as sess:\n",
    "    result = sess.run(fetches)\n",
    "\n",
    "anchor_matches = result['matches']\n",
    "# indices of anchors with a non-negative match value\n",
    "# (presumably the index of the matched groundtruth box - confirm in `match_boxes`)\n",
    "matched = np.nonzero(anchor_matches >= 0)[0]\n",
    "# indices of anchors marked with -2\n",
    "# (presumably anchors ignored during training - confirm in `match_boxes`)\n",
    "ignored = np.nonzero(anchor_matches == -2)[0]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Draw matched anchors, groundtruth boxes and ignored anchors"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# clip before the cast: float pixels outside of [0, 1]\n",
    "# would otherwise wrap around when converted to uint8\n",
    "pixels = np.clip(255.0 * result['image'], 0.0, 255.0).astype('uint8')\n",
    "annotated = Image.fromarray(pixels)\n",
    "\n",
    "# blue: matched anchors, red: groundtruth boxes, yellow: anchors marked with -2\n",
    "annotated = draw_boxes(annotated, result['anchors'][matched], color='blue', fill=(0, 0, 255, 5))\n",
    "annotated = draw_boxes(annotated, result['boxes'], color='red', fill=(255, 0, 0, 50))\n",
    "annotated = draw_boxes(annotated, result['anchors'][ignored], color='yellow')\n",
    "annotated"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
